# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0


from typing import Optional

import ttnn

import math


def find_closest_largest_divisor(num: int, start_divisor: int) -> int:
    """Return the largest divisor of ``num`` that is <= ``start_divisor``.

    Used to choose a core count that evenly divides a work quota. The search
    walks downwards from ``start_divisor`` and always terminates because 1
    divides every integer.

    Args:
        num: Value that the result must divide.
        start_divisor: Upper bound for the divisor; must be >= 1.

    Returns:
        The largest ``d`` with ``1 <= d <= start_divisor`` and ``num % d == 0``.

    Raises:
        ValueError: If ``start_divisor`` is smaller than 1. Without this check
            a value of 0 fails with ZeroDivisionError and negative values
            yield meaningless negative "divisors".
    """
    if start_divisor < 1:
        raise ValueError(f"start_divisor must be >= 1, got {start_divisor}")

    divisor = start_divisor
    while num % divisor != 0:
        divisor -= 1
    return divisor


def _golden_function(input_tensor: ttnn.Tensor, dim: Optional[int] = None, **_):
    """Reference softmax implemented with ``torch.nn.Softmax``.

    Args:
        input_tensor: Tensor handed directly to ``torch.nn.Softmax``.
        dim: Dimension along which softmax is computed. ``None`` selects the
            last dimension (-1).
        **_: Additional keyword arguments are accepted and ignored.

    Returns:
        ``torch.nn.Softmax(dim)(input_tensor)``.
    """
    import torch

    # Compare against None explicitly: `dim or -1` would silently replace the
    # valid value dim=0 with -1.
    if dim is None:
        dim = -1

    return torch.nn.Softmax(dim)(input_tensor)


# Attach the softmax reference defined directly above to both the out-of-place
# and the in-place op. This has to happen at this point in the module: the name
# `_golden_function` is rebound by the next definition below.
ttnn.attach_golden_function(
    ttnn.softmax,
    golden_function=_golden_function,
)

ttnn.attach_golden_function(
    ttnn.softmax_in_place,
    golden_function=_golden_function,
)


def _golden_function(input_tensor: ttnn.Tensor, scalar: float, attention_mask=None, **_):
    """Reference for the scale/mask softmax ops.

    Converts the input to float32, multiplies it by ``scalar``, adds
    ``attention_mask`` when one is given, and applies softmax over the last
    dimension.

    Args:
        input_tensor: Tensor to normalize.
        scalar: Multiplicative scale applied before the mask.
        attention_mask: Optional additive mask; skipped when ``None``.
        **_: Additional keyword arguments are accepted and ignored.

    Returns:
        The softmax result along ``dim=-1``.
    """
    import torch

    scaled = input_tensor.float() * scalar
    if attention_mask is None:
        return torch.softmax(scaled, dim=-1)
    return torch.softmax(scaled + attention_mask, dim=-1)


# Attach the scale/mask softmax reference defined directly above to the three
# ops that share it. This must run before `_golden_function` is rebound by the
# next definition further down in this module.
ttnn.attach_golden_function(
    ttnn.scale_mask_softmax_in_place,
    golden_function=_golden_function,
)

ttnn.attach_golden_function(
    ttnn.scale_mask_softmax,
    golden_function=_golden_function,
)

ttnn.attach_golden_function(
    ttnn.scale_causal_mask_hw_dims_softmax_in_place,
    golden_function=_golden_function,
)


# Module-level aliases for the softmax program-config classes exposed by
# ttnn._ttnn.operations.normalization.
SoftmaxProgramConfig = ttnn._ttnn.operations.normalization.SoftmaxProgramConfig
SoftmaxDefaultProgramConfig = ttnn._ttnn.operations.normalization.SoftmaxDefaultProgramConfig
SoftmaxShardedMultiCoreProgramConfig = ttnn._ttnn.operations.normalization.SoftmaxShardedMultiCoreProgramConfig


def _golden_function(
    input_tensor: ttnn.Tensor,
    *,
    epsilon=1e-12,
    residual_input_tensor=None,
    weight=None,
    bias=None,
    **_,
):
    """Reference layer norm over the last dimension.

    Args:
        input_tensor: Tensor to normalize. It is not modified.
        epsilon: Value passed as ``eps`` to ``torch.nn.functional.layer_norm``.
        residual_input_tensor: Optional tensor added to the input before
            normalization.
        weight: Optional scale. Inputs with two or more dimensions are
            squeezed, then cast to the dtype of the (summed) input.
        bias: Optional shift, prepared the same way as ``weight``.
        **_: Additional keyword arguments are accepted and ignored.

    Returns:
        The layer-normalized tensor, normalized over ``input_tensor.shape[-1]``.
    """
    import torch

    if residual_input_tensor is not None:
        # Out-of-place add: `+=` would write the sum back into the tensor the
        # caller passed in.
        input_tensor = input_tensor + residual_input_tensor

    def _prepare_affine(param):
        # Drop singleton dimensions of multi-dimensional parameters and match
        # the dtype of the tensor being normalized.
        if param is None:
            return None
        if len(param.shape) >= 2:
            param = param.squeeze()
        return param.to(input_tensor.dtype)

    weight = _prepare_affine(weight)
    bias = _prepare_affine(bias)

    normalized_shape = (input_tensor.shape[-1],)
    return torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps=epsilon)


# Attach the layer norm reference defined directly above; `_golden_function`
# is rebound by the next definition below, so this must run here.
ttnn.attach_golden_function(ttnn.layer_norm, golden_function=_golden_function)


def _golden_function(input_tensor: ttnn.Tensor, weight=None, *, epsilon=1e-12, **_):
    """Reference RMS norm over the last dimension.

    The mean of squares is computed in float32 and the input is multiplied by
    ``rsqrt(mean_square + epsilon)``.

    Args:
        input_tensor: Tensor to normalize.
        weight: Optional scale. When its dtype is float16 or bfloat16 the
            normalized tensor is cast to that dtype before scaling. When
            ``None`` the normalized tensor is returned unscaled.
        epsilon: Value added to the mean of squares before the reciprocal
            square root.
        **_: Additional keyword arguments are accepted and ignored.

    Returns:
        The normalized tensor, multiplied by ``weight`` when one is given.
    """
    import torch

    variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
    input_tensor = input_tensor * torch.rsqrt(variance + epsilon)

    # `weight` defaults to None; reading `weight.dtype` unconditionally would
    # raise AttributeError for that default.
    if weight is None:
        return input_tensor

    if weight.dtype in [torch.float16, torch.bfloat16]:
        input_tensor = input_tensor.to(weight.dtype)

    return weight * input_tensor


# Attach the RMS norm reference defined directly above; `_golden_function` is
# rebound again later in this module, so this must run here.
ttnn.attach_golden_function(ttnn.rms_norm, golden_function=_golden_function)

# Module-level aliases for the layer norm program-config classes exposed by
# ttnn._ttnn.operations.normalization.
LayerNormProgramConfig = ttnn._ttnn.operations.normalization.LayerNormProgramConfig
LayerNormDefaultProgramConfig = ttnn._ttnn.operations.normalization.LayerNormDefaultProgramConfig
LayerNormShardedMultiCoreProgramConfig = ttnn._ttnn.operations.normalization.LayerNormShardedMultiCoreProgramConfig


# group norm helper function
def determine_expected_group_norm_sharded_config_and_grid_size(
    *, device, num_channels, num_groups, input_nhw, is_height_sharded, is_row_major=False
):
    """Compute the sharded memory config and core grid for group norm.

    The NHW extent is counted in rows of 32 and spread over a core count that
    divides that row count. For block sharding the channels are additionally
    split over cores so that each core holds a whole number of groups.

    Args:
        device: Object providing ``compute_with_storage_grid_size()``; its
            ``x``/``y`` bound the usable grid.
        num_channels: Channel count; asserted divisible by ``num_groups`` and
            by 32.
        num_groups: Number of groups.
        input_nhw: Size of the flattened non-channel extent (presumably
            N*H*W, as the name suggests). It is rounded up to a multiple of
            ``32 * num_cores_nhw``.
        is_height_sharded: When true, shard along NHW only (one core across
            channels, HEIGHT strategy). Otherwise use the BLOCK strategy.
        is_row_major: When true, the two device grid dimensions are swapped,
            the shard shape is ``(1, 1, nhw_per_core, channels_per_core)``
            instead of ``(1, 1, channels_per_core, nhw_per_core)``, and the
            shard orientation is ROW_MAJOR.

    Returns:
        Tuple ``(memory_config, core_grid)`` where ``memory_config`` comes
        from ``ttnn.create_sharded_memory_config`` and ``core_grid`` is a
        ``ttnn.CoreGrid`` with the same dimensions.
    """
    assert num_channels % num_groups == 0
    assert num_channels % 32 == 0  # TODO: remove this later
    channels_per_group = num_channels // num_groups

    storage_grid = device.compute_with_storage_grid_size()
    if is_row_major:
        grid_dim0, grid_dim1 = storage_grid.y, storage_grid.x
    else:
        grid_dim0, grid_dim1 = storage_grid.x, storage_grid.y

    # Number of 32-row units needed to cover input_nhw.
    nhw_units = math.ceil(input_nhw / 32)
    if is_height_sharded:
        nhw_core_limit = grid_dim0 * grid_dim1
    else:
        nhw_core_limit = grid_dim0
    num_cores_nhw = find_closest_largest_divisor(nhw_units, nhw_core_limit)

    if is_height_sharded:
        num_cores_channels = 1
    else:
        channel_units = num_channels // 8
        num_cores_channels = grid_dim1
        # Shrink the channel core count until it divides the channel units and
        # every core receives a whole number of groups.
        while True:
            if (
                channel_units % num_cores_channels == 0
                and (num_channels // num_cores_channels) % channels_per_group == 0
            ):
                break
            num_cores_channels -= 1
            assert num_cores_channels > 0

    nhw_quantum = num_cores_nhw * 32
    nhw_padded = math.ceil(input_nhw / nhw_quantum) * nhw_quantum
    channels_per_core = num_channels // num_cores_channels
    assert channels_per_core % 8 == 0
    nhw_per_core = nhw_padded // num_cores_nhw

    if is_height_sharded:
        # Tightest rectangle (x, y) that can hold num_cores_nhw cores.
        grid_x = min(grid_dim0, num_cores_nhw)
        grid_y = math.ceil(num_cores_nhw / grid_dim0)
        assert (
            num_cores_nhw <= grid_x * grid_y
        ), "Error: For height sharding, num_cores_nhw must be <= grid size"
    elif is_row_major:
        grid_x, grid_y = num_cores_channels, num_cores_nhw
    else:
        grid_x, grid_y = num_cores_nhw, num_cores_channels

    if is_row_major:
        shard_shape = (1, 1, nhw_per_core, channels_per_core)
    else:
        shard_shape = (1, 1, channels_per_core, nhw_per_core)

    if is_height_sharded:
        shard_strategy = ttnn.ShardStrategy.HEIGHT
    else:
        shard_strategy = ttnn.ShardStrategy.BLOCK

    if is_height_sharded or is_row_major:
        shard_orientation = ttnn.ShardOrientation.ROW_MAJOR
    else:
        shard_orientation = ttnn.ShardOrientation.COL_MAJOR

    memory_config = ttnn.create_sharded_memory_config(
        shard_shape,
        ttnn.CoreGrid(y=grid_y, x=grid_x),
        shard_strategy,
        shard_orientation,
        use_height_and_width_as_shard_shape=True,
    )
    return memory_config, ttnn.CoreGrid(y=grid_y, x=grid_x)


def create_group_norm_weight_bias_rm(input_tensor, num_channels, num_cores_x):
    """Lay out a gamma/beta tensor as zero-padded rows of width 32.

    The values are viewed as chunks of ``num_channels // num_cores_x``
    elements; each chunk is right-padded with zeros up to the next multiple of
    32, and the result is returned with shape ``[1, 1, -1, 32]``.

    Args:
        input_tensor: Torch tensor holding the parameter values.
        num_channels: Number of channels covered by the parameter.
        num_cores_x: Number of chunks the channels are divided into.

    Returns:
        Torch tensor of shape ``[1, 1, rows, 32]``.
    """
    import torch

    chunk_len = num_channels // num_cores_x
    # Round the chunk length up to a multiple of 32 with integer arithmetic.
    padded_chunk_len = -(-chunk_len // 32) * 32
    pad_per_chunk = padded_chunk_len - chunk_len

    chunks = input_tensor.view(-1, chunk_len)
    flat = torch.nn.functional.pad(chunks, (0, pad_per_chunk)).flatten()

    # Keep only the elements belonging to num_channels worth of padded chunks.
    num_chunks = num_channels // chunk_len
    keep = num_channels + pad_per_chunk * num_chunks
    return flat[:keep].reshape(1, 1, -1, 32)


def dram_group_norm_virtual_columns(core_grid, num_channels, num_groups):
    """Choose the number of virtual columns for DRAM params/mask generation.

    Starting from ``min(core_grid.x, num_groups)``, the count is decreased
    until the channels per column are a whole multiple of ``ttnn.TILE_SIZE``.

    Args:
        core_grid: Object whose ``x`` attribute bounds the column count.
        num_channels: Total number of channels to distribute.
        num_groups: Number of groups; also an upper bound for the column count.

    Returns:
        The largest column count not exceeding the starting value for which
        ``num_channels`` is divisible by ``count * ttnn.TILE_SIZE``.

    Raises:
        ValueError: If no positive column count satisfies the condition.
            Without this check the search reaches zero and fails with
            ZeroDivisionError.
    """
    num_virtual_cols = min(core_grid.x, num_groups)
    # Integer form of "(num_channels / cols) is a multiple of TILE_SIZE";
    # avoids float division for what is a pure divisibility test.
    while num_virtual_cols > 0 and num_channels % (num_virtual_cols * ttnn.TILE_SIZE) != 0:
        num_virtual_cols -= 1
    if num_virtual_cols <= 0:
        raise ValueError(
            f"No virtual column count <= {min(core_grid.x, num_groups)} splits "
            f"{num_channels} channels into whole tiles of width {ttnn.TILE_SIZE}"
        )
    return num_virtual_cols


def dram_group_norm_params_from_torch(
    torch_params,
    channels_per_device,
    groups_per_device,
    device,
    mesh_axis=None,
    core_grid=None,
    return_mask=True,
    dtype=ttnn.bfloat16,
):
    """
    Create group norm parameters from torch in row major layout. It currently supports sharding along 1 mesh dimension. Sharding along 2 dimensions to be added as needed.
    Args:
        torch_params: List[torch.Tensor] or torch.Tensor. This is the weight and/or bias for the affine transformation.
        channels_per_device: Number of channels per device if using multi-device else number of channels
        groups_per_device: Number of groups per device if using multi-device else number of groups
        device: Device to create the group norm parameters on. Set to None if setting up on host. Must be provided if core_grid is None or if mesh_axis is set
        mesh_axis: Axis to shard the parameters on. Set to None if not sharding.
        core_grid: Core grid to use for the group norm parameters. Must be provided if device is None
        return_mask: Whether to return the mask.
        dtype: Data type to use for the group norm parameters.
    Returns:
        The prepared group norm parameters in the same order as torch_params. If return_mask is True, the mask is returned as well.
        Examples:
            Input: [torch_weight, torch_bias]   Output: ([tt_weight, tt_bias], tt_mask) if return_mask is True, [tt_weight, tt_bias] if return_mask is False
            Input: torch_weight                 Output: (tt_weight, tt_mask) if return_mask is True, tt_weight if return_mask is False
    """
    import torch

    assert core_grid or device, "Either core_grid or device must be provided to determin virtual columns"
    assert (
        channels_per_device % 32 == 0 == channels_per_device % groups_per_device
    ), f"channels_per_device {channels_per_device} must be divisible by 32 and groups_per_device {groups_per_device}"
    # The mesh shape is read from the device, so sharding needs a device.
    assert mesh_axis is None or device is not None, "device must be provided when mesh_axis is set"

    num_devices = 1
    mapper_dims = [None, None]
    if mesh_axis is not None:
        num_devices = tuple(device.shape)[mesh_axis]
        mapper_dims[mesh_axis] = 0  # sharding on channel dimension

    # Calculate number of virtual columns that will be used
    dev_core_grid = core_grid or device.core_grid
    num_virtual_cols = dram_group_norm_virtual_columns(dev_core_grid, channels_per_device, groups_per_device)

    single_tensor_input = isinstance(torch_params, torch.Tensor)
    torch_params_itr = [torch_params] if single_tensor_input else torch_params

    # Create prepared tensors for group norm
    tt_params = []
    for torch_param in torch_params_itr:
        computed_channels_per_device = torch_param.numel() // num_devices
        assert (
            computed_channels_per_device == channels_per_device
        ), f"Computed number of channels per device: {computed_channels_per_device} not equal to provided number of channels per device: {channels_per_device}"
        torch_sharded_lst = [
            ttnn.create_group_norm_weight_bias_rm(t, channels_per_device, num_virtual_cols)
            for t in torch_param.chunk(num_devices)
        ]
        tensor_to_shard = torch.cat(torch_sharded_lst, dim=0)

        # The documented host-only path passes device=None; `device.shape`
        # cannot be read then, so no mesh mapper is built in that case.
        mesh_mapper = None
        if device is not None:
            mesh_mapper = ttnn.ShardTensor2dMesh(device, mesh_shape=tuple(device.shape), dims=mapper_dims)

        tt_params.append(
            ttnn.from_torch(
                tensor_to_shard,
                dtype=dtype,
                device=device,
                mesh_mapper=mesh_mapper,
            )
        )

    # Mirror the input form: a bare tensor in gives a bare tensor out.
    tt_params = tt_params[0] if single_tensor_input else tt_params
    if not return_mask:
        return tt_params

    torch_mask = ttnn.create_group_norm_input_mask(channels_per_device, groups_per_device, num_virtual_cols)
    tt_mask = ttnn.from_torch(torch_mask, dtype=dtype, device=device, layout=ttnn.TILE_LAYOUT)
    return tt_params, tt_mask


def find_max_tile_span(W, group_size, tile_width):
    """Return the worst-case number of tiles a single group spans.

    Groups of ``group_size`` columns are laid out back to back starting at
    column 0 until column ``W`` is reached; for each group the number of
    ``tile_width``-wide tiles it touches is counted and the maximum is
    returned. This helps in setting the mask width conservatively.

    Args:
        W: Total width in columns. Values <= 0 give a span of 0.
        group_size: Number of columns per group; must be positive when
            ``W > 0``.
        tile_width: Width of one tile in columns; must be positive when
            ``W > 0``.

    Returns:
        The maximum tile span over all groups (0 when there are no groups).

    Raises:
        ValueError: If ``W > 0`` and ``group_size`` or ``tile_width`` is not
            positive. A non-positive ``group_size`` would otherwise never
            advance the scan position and loop forever.
    """
    if W <= 0:
        return 0
    if group_size <= 0 or tile_width <= 0:
        raise ValueError(
            f"group_size and tile_width must be positive, got group_size={group_size}, tile_width={tile_width}"
        )

    def _tile_span(group_start):
        # Tiles touched by the half-open column range [group_start, group_start + group_size).
        first_tile = group_start // tile_width
        last_tile = (group_start + group_size - 1) // tile_width
        return last_tile - first_tile + 1

    return max(_tile_span(group_start) for group_start in range(0, W, group_size))


def create_group_norm_mask_impl(num_channel, num_groups, num_cores_across_channel, is_negative_mask=False):
    """Build the 4D group norm mask of shape [1, num_groups, 32, 32 * block_wt].

    - ``block_wt`` is the worst-case tile span of one group (tile width 32).
    - For each group, the columns ``[start, start + channels_per_group)``
      (clipped to the mask width) are marked. ``start`` is the group's column
      offset inside a 32-wide tile and restarts at 0 for the first group of
      every core.
    - Positive mask: zeros with marked columns set to 1. Negative mask: ones
      with marked columns set to 0.

    Args:
        num_channel: Total number of channels.
        num_groups: Number of groups.
        num_cores_across_channel: Number of cores the groups are divided over.
        is_negative_mask: Selects the inverted (negative) mask.

    Returns:
        bfloat16 torch tensor containing the mask.
    """
    import torch

    cols_per_group = num_channel // num_groups
    block_wt = find_max_tile_span(num_channel, cols_per_group, 32)
    mask_shape = (1, num_groups, 32, int(32 * block_wt))

    # Explicit comparison kept on purpose so that non-bool flag values behave
    # exactly as they did before.
    positive_mask = is_negative_mask == False
    if positive_mask:
        mask = torch.zeros(mask_shape, dtype=torch.bfloat16)
    else:
        mask = torch.ones(mask_shape, dtype=torch.bfloat16)

    groups_per_core = num_groups // num_cores_across_channel

    # Column offset of each group within a 32-wide tile; wraps modulo 32 and
    # restarts at 0 for every core.
    step = cols_per_group % 32
    start_offsets = []
    for _ in range(num_cores_across_channel):
        offset = 0
        start_offsets.append(offset)
        for _ in range(groups_per_core - 1):
            offset = (offset + step) % 32
            start_offsets.append(offset)

    fill_value = 1 if positive_mask else 0
    mask_width = mask.shape[3]
    for group in range(num_groups):
        begin = start_offsets[group]
        end = min(begin + cols_per_group, mask_width)
        mask[:, group, :, begin:end] = fill_value

    return mask


def create_group_norm_reciprocals_impl(N, C, H, W, num_groups, core_grid):
    """
    Create the reciprocals tensor used by group norm with the Welford algorithm.

    Produces the values 1/1, 1/2, ..., 1/K in float32, where K is derived from
    the channels per group, the number of ``ttnn.TILE_SIZE`` tiles covering
    H * W, and the virtual-row layout of the core grid. The row of values is
    then repeated once per core of the grid so that every core can receive an
    identical copy.

    Args:
        N: Batch size
        C: Number of channels
        H: Height
        W: Width
        num_groups: Number of groups
        core_grid: Core grid

    Returns:
        Torch tensor of shape [core_grid.x * core_grid.y, K] with the
        reciprocal values in every row
    """
    import torch

    virtual_cols = dram_group_norm_virtual_columns(core_grid, C, num_groups)
    virtual_rows = (core_grid.x // virtual_cols) * core_grid.y

    # Spread the virtual rows over the batch entries when there are more rows
    # than batch entries.
    if N >= virtual_rows:
        rows_per_group = 1
    else:
        rows_per_group = virtual_rows // N

    channels_per_group = C // num_groups
    height_tiles = math.ceil(H * W / ttnn.TILE_SIZE)
    reciprocals_per_core = (channels_per_group * height_tiles) // rows_per_group

    denominators = torch.arange(1, reciprocals_per_core + 1, dtype=torch.float32)
    reciprocals = 1.0 / denominators

    total_cores = core_grid.x * core_grid.y
    return reciprocals.repeat(total_cores, 1)


def create_group_norm_input_mask(num_channel, num_groups, num_cores_across_channel):
    """Build the positive group norm mask (1 on each group's columns, 0 elsewhere)."""
    return create_group_norm_mask_impl(
        num_channel=num_channel,
        num_groups=num_groups,
        num_cores_across_channel=num_cores_across_channel,
        is_negative_mask=False,
    )


def create_group_norm_input_negative_mask(num_channel, num_groups, num_cores_across_channel):
    """Build the negative group norm mask (0 on each group's columns, 1 elsewhere)."""
    return create_group_norm_mask_impl(
        num_channel=num_channel,
        num_groups=num_groups,
        num_cores_across_channel=num_cores_across_channel,
        is_negative_mask=True,
    )


def create_group_norm_reciprocals(N, C, H, W, num_groups, core_grid):
    """Public entry point; forwards all arguments to ``create_group_norm_reciprocals_impl``."""
    return create_group_norm_reciprocals_impl(
        N=N,
        C=C,
        H=H,
        W=W,
        num_groups=num_groups,
        core_grid=core_grid,
    )


def get_group_norm_cores_across_channel(memory_layout, core_grid):
    """Return how many cores the channel axis is split over.

    Args:
        memory_layout: A ``ttnn.types.TensorMemoryLayout`` value.
        core_grid: Object with ``x`` and ``y`` attributes.

    Returns:
        ``core_grid.y`` for BLOCK_SHARDED, 1 for HEIGHT_SHARDED, and
        ``core_grid.x * core_grid.y`` for any other layout.
    """
    layouts = ttnn.types.TensorMemoryLayout
    if memory_layout == layouts.BLOCK_SHARDED:
        return core_grid.y
    if memory_layout == layouts.HEIGHT_SHARDED:
        return 1
    return core_grid.x * core_grid.y


def _golden_function(
    input_tensor: ttnn.Tensor,
    *,
    num_groups,
    epsilon=1e-05,
    weight=None,
    bias=None,
    memory_config=None,
    core_grid=None,
    input_mask=None,
    **kwargs,
):
    """Reference group norm for a channels-last 4D input.

    The input is permuted with ``(0, 3, 1, 2)`` so the last dimension becomes
    the channel dimension expected by ``torch.nn.functional.group_norm``, and
    the result is permuted back with ``(0, 2, 3, 1)``.

    Args:
        input_tensor: 4D tensor whose last dimension is the channel count.
        num_groups: Number of groups.
        epsilon: Value passed as ``eps`` to torch group norm.
        weight: Optional per-core padded scale. It is reshaped to one row per
            channel-splitting core and each row is cut to
            ``num_channels // cores`` values before use.
        bias: Optional per-core padded shift, handled like ``weight``.
        memory_config: Provides ``memory_layout``; only read when ``weight``
            or ``bias`` is given.
        core_grid: Core grid used to derive the number of channel-splitting
            cores; only read when ``weight`` or ``bias`` is given.
        input_mask: Accepted for signature compatibility; unused.
        **kwargs: Additional keyword arguments are accepted and ignored.

    Returns:
        The float32 group norm result in the input's channels-last order.
    """
    import torch

    num_channels = input_tensor.shape[-1]

    # weight and bias both default to None and torch's group_norm accepts None
    # for either, so only unpad the parameters that were actually provided.
    if weight is not None or bias is not None:
        num_cores_across_channel = get_group_norm_cores_across_channel(memory_config.memory_layout, core_grid)
        channels_per_core = num_channels // num_cores_across_channel

        def _strip_core_padding(param):
            # One row per core; keep only the real channels of each row.
            if param is None:
                return None
            per_core = param.reshape((num_cores_across_channel, -1))
            return per_core[:, :channels_per_core].flatten().float()

        weight = _strip_core_padding(weight)
        bias = _strip_core_padding(bias)

    input_tensor = input_tensor.permute(0, 3, 1, 2)
    output = torch.nn.functional.group_norm(input_tensor.float(), num_groups, weight, bias, eps=epsilon)
    return output.permute(0, 2, 3, 1)


def _postprocess_golden_function_outputs(output, args, kwargs):
    """Reshape ``output`` to the shape of the first positional argument.

    Args:
        output: Value to reshape with ``ttnn.reshape``.
        args: Positional arguments of the call; ``args[0]`` supplies the
            target shape.
        kwargs: Keyword arguments of the call; unused.

    Returns:
        The reshaped output.
    """
    target_shape = args[0].shape
    return ttnn.reshape(output, target_shape)


# Attach the group norm reference together with the postprocessing hook that
# reshapes its output to the shape of the first positional argument.
ttnn.attach_golden_function(
    ttnn.group_norm,
    golden_function=_golden_function,
    postprocess_golden_function_outputs=_postprocess_golden_function_outputs,
)

# Empty on purpose: a star-import of this module exports no names.
__all__ = []
